# Name of this task.
name: huggingface

resources:
  # One V100 GPU, written in the explicit <name>:<count> form.
  # A bare `V100` is shorthand for the same request (count defaults to 1).
  accelerators: V100:1

# The setup command.  Will be run under the working directory.
# Installs transformers v4.25.1 from source plus the dependencies of the
# text-classification example.
setup: |
  # Stop at the first failing command so that a broken clone/checkout/install
  # is reported instead of being masked by a later command succeeding.
  set -e
  # Clone only if the checkout is not already present: `git clone` fails when
  # the destination directory exists, which would break a repeated setup.
  [ -d transformers ] || git clone https://github.com/huggingface/transformers/
  cd transformers
  # Check out the pinned release so the library and the example script match.
  git checkout v4.25.1
  pip3 install .
  cd examples/pytorch/text-classification
  # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5).
  # The cu113 build of torch is pulled from the PyTorch wheel index.
  pip3 install -r requirements.txt tensorboard torch==1.12.1+cu113 \
    --extra-index-url https://download.pytorch.org/whl/cu113

# The command to run.  Will be run under the working directory.
# Trains bert-base-cased on the imdb dataset for 50 steps in fp16, writing
# results to /tmp/imdb/ (any existing output there is overwritten).
run: |
  cd transformers/examples/pytorch/text-classification
  python3 run_glue.py \
    --model_name_or_path bert-base-cased \
    --dataset_name imdb \
    --do_train \
    --max_seq_length 128 \
    --per_device_train_batch_size 32 \
    --learning_rate 2e-5 \
    --max_steps 50 \
    --output_dir /tmp/imdb/ \
    --overwrite_output_dir \
    --fp16
